# SPDX-License-Identifier: Apache-2.0
"""Kernel test utils"""

import itertools
import random
import unittest
from collections.abc import Sequence
from numbers import Number
from typing import Any, NamedTuple, Optional, Union

import pytest
import torch
from torch._prims_common import TensorLikeType

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

# "test_aot_dispatch_dynamic" is kept out of the default set for now, since
# there are some bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# The default test utils plus "test_aot_dispatch_dynamic".
ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
    *DEFAULT_OPCHECK_TEST_UTILS,
    "test_aot_dispatch_dynamic",
)


class QKVInputs(NamedTuple):
    '''
    Unpacked (padded) attention inputs together with the true length
    of every sequence in the batch.

    Attributes:

        * query: unpacked query tensor, (batch_size x padded_seq_len x
                 num_heads x head_size)
        * key: unpacked key tensor, same layout as query
        * value: unpacked value tensor, same layout as query
        * q_seq_lens: list of query sequence lengths
        * kv_seq_lens: list of sequence lengths shared by key and value
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: list[int]
    kv_seq_lens: list[int]


class QKVO(NamedTuple):
    '''
    Pairs unpacked attention inputs with the unpacked known-correct
    attention output for those inputs.

    Attributes:

        * qkv: unpacked (batch_size x padded_seq_len x
               num_heads x head_size) attention inputs
        * ideal_output: unpacked (batch_size x padded_seq_len x
                        num_heads x head_size) known-correct
                        attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Packed (unpadded) attention inputs plus the bookkeeping needed to
    locate each sequence inside the packed tensors.

    Attributes:

        * query: packed query tensor, (number_of_tokens x num_heads
                 x head_size)
        * key: packed key tensor, same layout as query
        * value: packed value tensor, same layout as query
        * q_start_loc_list: start location of each query sequence within
                            the packed query tensor
        * kv_start_loc_list: start location of each key/value sequence
                             within the packed key & value tensors
                             (shared by key and value)
        * q_seq_lens: list of query sequence lengths
        * kv_seq_lens: list of sequence lengths shared by key and value
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[list[int]]
    kv_start_loc_list: Optional[list[int]]
    q_seq_lens: Optional[list[int]]
    kv_seq_lens: Optional[list[int]]


class PackedQKVO(NamedTuple):
    '''
    Pairs packed attention inputs with the packed known-correct
    attention output for those inputs.

    Attributes:

        * packed_qkv: packed (number_of_tokens x num_heads
                      x head_size) attention inputs; may be None
        * ideal_output: packed (number_of_tokens x num_heads
                        x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Bundle of the two tensors which describe a KV cache memory mapping.

    Attributes:

        * block_tables: KV cache block tables
        * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Test parameters for one test "phase" (prefill or decode phase)
    under one attention scenario (encoder, decoder-self,
    encoder/decoder-cross).

    Attributes:

        * packed_qkvo: packed (number_of_tokens x num_heads
                       x head_size) attention inputs & known-correct
                       output
        * kv_mmap: KV cache memory mapping, specific to this test phase &
                   attention scenario; may be None
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]


def maybe_make_int_tensor(
    _list: Optional[list[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D int torch.Tensor on `device`

    Arguments:

    * _list: list of ints to convert, or None
    * device: device on which the tensor is created

    Returns:

    * If _list is not None: 1D int (torch.int) torch.Tensor on `device`
    * None otherwise
    '''
    # None is passed straight through so that callers can forward
    # optional metadata without checking it themselves
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[list[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D long torch.Tensor on `device`

    Arguments:

    * _list: list of ints to convert, or None
    * device: device on which the tensor is created

    Returns:

    * If _list is not None: 1D long (torch.long) torch.Tensor on `device`
    * None otherwise
    '''
    # None is passed straight through so that callers can forward
    # optional metadata without checking it themselves
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[list]) -> Optional[Number]:
    '''
    None-tolerant wrapper around max()

    Returns:

    * None if _list is None
    * max(_list) otherwise
    '''
    if _list is None:
        return None
    return max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Build an additive q_max_seq_len x kv_max_seq_len causal mask

    Entry (i, j) is -inf when j > i (strictly above the main diagonal)
    and 0 otherwise.

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # True strictly above the main diagonal, i.e. wherever j > i
    above_diagonal = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len),
                                diagonal=1).bool()
    # Start from all zeros and drop -inf into the positions above the diagonal
    no_mask = torch.zeros(q_max_seq_len, kv_max_seq_len)
    return no_mask.masked_fill(above_diagonal, float('-inf'))


def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Temporarily point the vLLM backend environment variable
    (STR_BACKEND_ENV_VAR) at `backend_name`. The change goes through
    pytest monkeypatch so that the env var is reset once the test
    context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    '''
    mpatch.setenv(name=STR_BACKEND_ENV_VAR, value=backend_name)


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[list] = None,
                         kv_seq_lens: Optional[list] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Two kinds of mask are applied
    additively to the scaled attention scores before the softmax:

    * Padding mask, built from {q,kv}_seq_lens: query rows and key/value
      columns at or beyond a sequence's length are set to -inf
    * Optional custom mask, which can force an arbitrary mask tensor, i.e.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index;
      required (asserted non-None) despite the None default
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index;
      required (asserted non-None) despite the None default

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    num_seqs = query.shape[0]
    assert (len(q_seq_lens) == num_seqs)
    assert (len(kv_seq_lens) == num_seqs)

    # Scores are laid out batch x heads x q_len x kv_len, in float32
    scores = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Padding mask, derived from seq lens
    have_q_lens = q_seq_lens is not None
    have_kv_lens = kv_seq_lens is not None
    if have_q_lens or have_kv_lens:
        padding_mask = torch.zeros_like(scores)
        if have_q_lens:
            for batch_idx, q_len in enumerate(q_seq_lens):
                padding_mask[batch_idx, :, q_len:, :] = -torch.inf
        if have_kv_lens:
            for batch_idx, kv_len in enumerate(kv_seq_lens):
                padding_mask[batch_idx, :, :, kv_len:] = -torch.inf

        scores = scores + padding_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        scores = scores + custom_mask.float()

    # Softmax over the key/value axis, then cast back to the value dtype
    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("bhqk,bkhd->bqhd", probs, value)


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[list[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len; also the padded query length
    * max_kv_seq_len: max key/value seq len; also the padded key/value
      length. Must not be None for enc/dec cross-attention unless
      force_kv_seq_lens is given. If None otherwise, the padded key/value
      length falls back to max(force_kv_seq_lens) when kv seq lens are
      forced, else to max_q_seq_len
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention. For
      enc/dec cross attention, key/value seqlens are chosen independently
      of query seqlens; o/w key/value seqlens match the query seqlens at
      each batch index
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2,max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len in the enc/dec cross attention case

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]

    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    if max_kv_seq_len is None:
        # max_kv_seq_len is Optional, but the K/V tensors below still need
        # a concrete padded length; without this fallback torch.rand would
        # be handed None as a dimension.
        if force_kv_seq_lens is not None:
            max_kv_seq_len = max(kv_seq_lens, default=0)
        else:
            max_kv_seq_len = max_q_seq_len

    q_shape = (batch_size, max_q_seq_len, num_heads, head_size)
    kv_shape = (batch_size, max_kv_seq_len, num_heads, head_size)
    single_tok_shape = (batch_size, 1, num_heads, head_size)

    # Baseline tensors are random; the order of the rand() calls (query,
    # key, value) is kept fixed so that seeded runs stay reproducible
    query = torch.rand(q_shape).to(device)
    key = torch.rand(kv_shape).to(device)
    value = torch.rand(kv_shape).to(device)

    prefill_query = torch.zeros(q_shape).to(device)
    prefill_key = torch.zeros(kv_shape).to(device)
    prefill_value = torch.zeros(kv_shape).to(device)

    decode_query = torch.zeros(single_tok_shape).to(device)
    decode_key = torch.zeros(single_tok_shape).to(device)
    decode_value = torch.zeros(single_tok_shape).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        # Zero out the padding region of the baseline tensors
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill holds every offset except the last one of each sequence
        q_prefill_len = q_seq_len - 1
        kv_prefill_len = kv_seq_len - 1
        prefill_query[bdx, 0:q_prefill_len, :, :] = query[
            bdx, 0:q_prefill_len, :, :]
        prefill_key[bdx, 0:kv_prefill_len, :, :] = key[
            bdx, 0:kv_prefill_len, :, :]
        prefill_value[bdx, 0:kv_prefill_len, :, :] = value[
            bdx, 0:kv_prefill_len, :, :]

        # Decode holds only the last offset of each sequence
        decode_query[bdx, :, :, :] = query[bdx,
                                           q_prefill_len:q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, kv_prefill_len:kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx,
                                           kv_prefill_len:kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of KV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))


def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: list[int],
        device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    The packed tensor keeps the dtype of unpacked_tensor.

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens)), i.e. len(seq_lens) + 1 entries
      with the total token count as the last one
    '''

    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    # The last cumulative sum is the total number of tokens
    num_tok = start_loc_list[-1]
    # Allocate with the source dtype; the default dtype would silently
    # convert e.g. half-precision inputs to float32
    packed_tensor = torch.zeros((num_tok, num_heads, head_size),
                                dtype=unpacked_tensor.dtype,
                                device=device)

    # zip() stops at the shorter seq_lens, so the trailing total in
    # start_loc_list is never used as a start location
    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list


def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Pack Q, K and V separately via pack_tensor(): each goes from
    batch_size x padded_seq_len x num_heads x head_size to
    number_of_tokens x num_heads x head_size.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    If qkv.query is None, the packed query, the query start locations and
    the query seq lens in the result are all None.

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           attention inputs
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    # Query side is optional
    packed_query = None
    q_start_loc_list = None
    packed_q_seq_lens = None
    if qkv.query is not None:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
        packed_q_seq_lens = qkv.q_seq_lens

    # Key and value share seq lens, hence also share start locations
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)

    return PackedQKVInputs(query=packed_query,
                           key=packed_key,
                           value=packed_value,
                           q_start_loc_list=q_start_loc_list,
                           kv_start_loc_list=kv_start_loc_list,
                           q_seq_lens=packed_q_seq_lens,
                           kv_seq_lens=qkv.kv_seq_lens)


def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance selected by the backend_name string
    argument.

    STR_XFORMERS_ATTN_VAL -> construct xformers backend
    STR_FLASH_ATTN_VAL -> construct flash-attention backend

    Any other backend_name raises AssertionError.

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance
    '''
    if backend_name not in (STR_XFORMERS_ATTN_VAL, STR_FLASH_ATTN_VAL):
        raise AssertionError(
            f"Unrecognized backend_name {backend_name} for unit test")

    # Backend modules are imported lazily, only for the selected backend
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
        from vllm.attention.backends.xformers import XFormersBackend
        backend_cls = XFormersBackend
    else:
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        backend_cls = FlashAttentionBackend

    return backend_cls()


def _seq_start_loc_from_lens(
        lens_tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    '''
    Build the int32 start-location tensor [0, cumsum(lens_tensor)] on the
    same device as lens_tensor; None if lens_tensor is None.
    '''
    if lens_tensor is None:
        return None
    start_loc = torch.zeros(lens_tensor.shape[0] + 1,
                            dtype=torch.int32,
                            device=lens_tensor.device)
    # Entry 0 stays zero; the running totals are written in place after it
    torch.cumsum(lens_tensor,
                 dim=0,
                 dtype=start_loc.dtype,
                 out=start_loc[1:])
    return start_loc


def _make_metadata_tensors(
    seq_lens: Optional[list[int]],
    context_lens: Optional[list[int]],
    encoder_seq_lens: Optional[list[int]],
    device: Union[torch.device, str],
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Any, Any,
           Optional[torch.Tensor], Optional[torch.Tensor],
           Optional[torch.Tensor], Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Every input list may be None, in which case the values derived from it
    are None as well.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * encoder_seq_start_loc: start idx of each encoder sequence
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = _seq_start_loc_from_lens(seq_lens_tensor)
    # encoder_seq_lens is None when no encoder is involved; the encoder
    # start locations must then be None instead of dereferencing a None
    # tensor
    encoder_seq_start_loc = _seq_start_loc_from_lens(encoder_seq_lens_tensor)

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
            max_encoder_seq_len)


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  backend: str,
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * backend: 'XFORMERS' or 'FLASH_ATTN'; selects the cache layout.
      Any other value raises ValueError
    * default_val: initialization value for KV cache elements; if None,
      the cache keeps its random (torch.rand) contents

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    *     for backend 'XFORMERS'
    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
    *     for backend 'FLASH_ATTN'
    '''
    # Pick the cache layout for the requested backend
    if backend == 'XFORMERS':
        cache_shape: tuple[int, ...] = (2, num_blocks,
                                        block_size * num_heads * head_size)
    elif backend == 'FLASH_ATTN':
        cache_shape = (2, num_blocks, block_size, num_heads, head_size)
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")

    kv_cache = torch.rand(cache_shape).to(device)
    if default_val is not None:
        # Slicing the three leading dims covers the whole tensor for both
        # layouts
        kv_cache[:, :, :] = default_val
    return kv_cache


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Number of blocks provisioned to hold num_tokens tokens, given
    block_size: one block more than the number of completely filled
    blocks.

    Note that a num_tokens which is an exact multiple of block_size still
    gets the extra block (e.g. 16 tokens with block_size 16 -> 2 blocks).
    '''
    full_blocks = num_tokens // block_size
    return full_blocks + 1


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    '''
    Create an empty (zero-element) long slot-mapping tensor on `device`
    '''
    no_slots: list[int] = []
    return maybe_make_long_tensor(no_slots, device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    '''
    Create an empty (zero-element) block-tables tensor on `device`.

    No dtype is requested, so the tensor has torch's default dtype.
    '''
    empty_block_tables = torch.tensor([], device=device)
    return empty_block_tables


def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * The scenario under test is (1) prefill of N prompts, with
      prompt-lengths {K_i \\forall i \\in [0,N)}, followed by (2) decoding
      of a single token for all N prompts (N tokens total); the resultant
      sequence lengths after decode would be {K_i + 1 for i \\in [0,N)}
    * That scenario requires (1) the prefill slot mapping for all tokens
      present during prefill, the number of which is M = \\sum_i{K_i},
      and (2) the decode slot mapping for all N decoded tokens

    The input is a single 1D slot mapping: the concatenation of N slot
    mappings, each of length K_i + 1 (corresponding to the sequence lengths
    after decode), for a total length of P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping is obtained by dropping the (K_i + 1)-th
    (i.e. last) entry of each of the N subsequences, which is the decoded
    token's mapping.

    The N dropped entries, in sequence order, form the decode-phase slot
    mapping.

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
      post-decode sequences
    * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slots = []
    decode_slots = []

    seq_start = 0
    for seq_len in seq_lens:
        # Index of this sequence's final entry (the decoded token)
        last_idx = seq_start + seq_len - 1
        prefill_slots.extend(slot_mapping_list[seq_start:last_idx])
        decode_slots.append(slot_mapping_list[last_idx])
        seq_start += seq_len

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slots, device)
    decode_slot_mapping = maybe_make_long_tensor(decode_slots, device)
    return (prefill_slot_mapping, decode_slot_mapping)


def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: list[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]:
    '''
    Construct fake block tables & slot mappings.

    The number of KV cache blocks provisioned for a sequence with
    num_tokens tokens is

    num_blocks = (num_tokens + block_size) // block_size

    so the KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Block indices are handed out counting downward, starting at

    block_base_addr + total_cache_blocks

    and heading toward

    block_base_addr

    The first sequence receives the highest block indices; each following
    sequence continues downward from where the previous one stopped.

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: int tensor with one zero-padded block-table row
      per sequence; rows are padded to 10 columns beyond the longest
      block table
    * slot_mapping_list: slot mapping covering all tokens of all sequences,
      in sequence order
    * max_block_idx: the highest block address within this block table
    '''

    # Blocks provisioned per sequence
    blocks_per_seq = [
        _num_tokens_to_min_blocks(seq_len, block_size) for seq_len in seq_lens
    ]
    longest_block_table = max(blocks_per_seq)
    extra_pad_columns = 10

    # Uppermost block index; allocation walks downward from here
    max_block_idx = block_base_addr + sum(blocks_per_seq)
    next_block_idx = max_block_idx

    block_tables = []
    slot_mapping_list = []
    for num_tokens, num_blocks in zip(seq_lens, blocks_per_seq):
        block_table = list(
            range(next_block_idx, next_block_idx - num_blocks, -1))
        # Slot = (physical block index * block_size) + offset within block
        slot_mapping_list.extend(
            block_table[tok_idx // block_size] * block_size +
            (tok_idx % block_size) for tok_idx in range(num_tokens))
        block_tables.append(block_table)
        next_block_idx -= num_blocks

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=longest_block_table + extra_pad_columns,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)


def make_test_metadata(
    attn_backend: _Backend,
    is_prompt: bool,
    seq_lens: Optional[list[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    if encoder_test_params and cross_test_params are None, the attention
    metadata will support decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels; its name is
                    used to pick the backend which builds the metadata
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence; must not be None
                for decode
    * decoder_test_params: decoder self-attention test params;
                           this function requires
                           kv_mmap (memory mapping) field, which must
                           not be None for decode
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping; decoder_test_params is None
    # signals encoder-only scenario, so kv_mmap is None
    kv_mmap = (decoder_test_params.kv_mmap
               if decoder_test_params is not None else None)

    # No chunked prefill is assumed, so the batch is 100% prefill tokens or
    # 100% decode tokens:
    #
    # - If is_prompt, num_seqs is the number of prefills and num_tokens is
    #   the number of prefill tokens
    # - If not is_prompt, num_seqs is the number of decodes and num_tokens
    #   is the number of decode tokens (one per sequence)
    #
    # seq_lens is None signals encoder-only scenario, in which case both
    # are None
    if seq_lens is None:
        num_seqs = None
        num_tokens = None
    else:
        num_seqs = len(seq_lens)
        num_tokens = sum(seq_lens) if is_prompt else len(seq_lens)

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    # Encoder/decoder or encoder-only models only:
    # * Extract encoder input sequence lengths
    encoder_seq_lens = None
    num_encoder_tokens = None
    if encoder_test_params is not None:
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        if encoder_seq_lens is not None:
            num_encoder_tokens = sum(encoder_seq_lens)

    # Encoder/decoder or encoder-only models only:
    # * Extract *cross-attention* slot_mapping and block table
    #   (kv_mmap)
    cross_kv_mmap = (None if cross_test_params is None else
                     cross_test_params.kv_mmap)

    attn_backend_obj = make_backend(attn_backend.name)

    if not is_prompt:
        # Decode-phase scenario needs the decoder memory mapping and
        # sequence lengths
        assert kv_mmap is not None
        assert num_tokens is not None
        assert seq_lens is not None

    (
        seq_lens_tensor,
        context_lens_tensor,
        _,
        _,
        seq_start_loc,
        encoder_seq_lens_tensor,
        encoder_seq_start_loc,
        max_encoder_seq_len,
    ) = _make_metadata_tensors(seq_lens,
                               context_lens,
                               encoder_seq_lens,
                               device=device)

    # Metadata fields which are filled identically for prefill and decode
    shared_fields = dict(
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
        seq_lens=seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        seq_start_loc=seq_start_loc,
        context_lens_tensor=context_lens_tensor,
        use_cuda_graph=False,
        num_encoder_tokens=num_encoder_tokens,
        encoder_seq_lens=encoder_seq_lens,
        encoder_seq_lens_tensor=encoder_seq_lens_tensor,
        encoder_seq_start_loc=encoder_seq_start_loc,
        max_encoder_seq_len=max_encoder_seq_len,
        cross_slot_mapping=(None if cross_kv_mmap is None else
                            cross_kv_mmap.slot_mapping),
        cross_block_tables=(None if cross_kv_mmap is None else
                            cross_kv_mmap.block_tables),
    )

    if is_prompt:
        # Prefill-phase scenario
        return attn_backend_obj.make_metadata(
            num_prefills=num_seqs,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_tokens,
            num_decode_tokens=0,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            **shared_fields)

    # Decode-phase scenario
    return attn_backend_obj.make_metadata(
        num_prefills=0,
        slot_mapping=kv_mmap.slot_mapping,
        num_prefill_tokens=0,
        num_decode_tokens=num_tokens,
        max_prefill_seq_len=0,
        max_decode_seq_len=max(seq_lens),
        max_decode_query_len=1,
        block_tables=kv_mmap.block_tables,
        **shared_fields)


def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor,
                                backend: str) -> None:
    '''
    Check the observed attention output against the known-correct (ideal)
    output stored in the test parameters.

    The observed output is viewed with the ideal output's shape before the
    comparison. 'XFORMERS' is compared with the default tolerances of
    torch.testing.assert_close; 'FLASH_ATTN' is compared with explicitly
    loosened tolerances.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    * backend: name of the backend that produced the output, either
               'XFORMERS' or 'FLASH_ATTN'

    Raises:

    * ValueError: backend is neither 'XFORMERS' nor 'FLASH_ATTN'
    * AssertionError: raised by torch.testing.assert_close on a mismatch
    '''
    expected = test_params.packed_qkvo.ideal_output

    tolerances: dict[str, float]
    if backend == 'XFORMERS':
        # Default assert_close thresholds.
        tolerances = {}
    elif backend == 'FLASH_ATTN':
        # For FlashAttention override the accuracy thresholds to non default
        # values since we notice a higher difference between the ideal and
        # actual output.
        tolerances = {"atol": 0.01, "rtol": 0.016}
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")

    torch.testing.assert_close(expected,
                               output_under_test.view_as(expected),
                               **tolerances)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose that compares in float64.

    The arguments are first validated with torch's own allclose argument
    checker, then both operands are upcast with ``.double()`` and compared
    element-wise via torch.isclose.

    Args:
        a: first tensor.
        b: second tensor.
        rtol: relative tolerance forwarded to torch.isclose.
        atol: absolute tolerance forwarded to torch.isclose.
        equal_nan: whether NaNs in matching positions count as close.

    Returns:
        True if every element pair is close, False otherwise.
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    # Upcast both sides so the comparison itself runs in float64.
    elementwise_close = torch.isclose(a.double(),
                                      b.double(),
                                      rtol=rtol,
                                      atol=atol,
                                      equal_nan=equal_nan)
    everything_close = torch.all(elementwise_close)
    return bool(everything_close.item())


# Marlin MoE test utils


def stack_and_dev(tensors: list[torch.Tensor]):
    """
    Stack ``tensors`` along a new leading dimension and move the result to
    the device of the first tensor in the list.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)


def compute_max_diff(output, output_ref):
    """
    Return the relative difference between ``output`` and ``output_ref``.

    Despite the name, the metric is mean-based: the mean absolute
    difference divided by the mean absolute value of the reference.
    """
    abs_error = torch.abs(output - output_ref)
    ref_magnitude = torch.abs(output_ref)
    return torch.mean(abs_error) / torch.mean(ref_magnitude)


def torch_moe(a, w1, w2, score, topk, expert_map):
    """
    Plain-PyTorch reference for a top-k routed mixture-of-experts layer.

    Each row of ``a`` is duplicated ``topk`` times and routed to the experts
    with the highest softmax scores. For expert ``i`` the routed rows are
    multiplied by ``w1[i]`` transposed, passed through SiluAndMul, then
    multiplied by ``w2[i]`` transposed. The per-expert results are scaled by
    the corresponding top-k softmax weights and summed per input row.

    Args:
        a: 2-D input, unpacked as (B, D).
        w1: per-expert first weights, indexed by expert along dim 0.
        w2: per-expert second weights; ``w2.shape[1]`` is the output width.
        score: routing logits; softmax is taken over the last dim in float32.
        topk: number of experts selected per input row.
        expert_map: optional tensor used to remap the selected expert ids
            (``expert_map[ids]``); pass None to use the ids unchanged. Ids
            outside ``range(w1.shape[0])`` after remapping contribute zeros.

    Returns:
        Tensor of shape (B, w2.shape[1]).
    """
    num_rows, width = a.shape
    out_width = w2.shape[1]

    # One copy of every input row per selected expert: (B * topk, D).
    routed = a.view(num_rows, -1, width).repeat(1, topk, 1).reshape(-1, width)
    expert_out = torch.zeros(num_rows * topk,
                             out_width,
                             dtype=routed.dtype,
                             device=routed.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_weights, expert_ids = torch.topk(probs, topk)
    routing_weights = routing_weights.view(-1)
    expert_ids = expert_ids.view(-1)
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    for expert in range(w1.shape[0]):
        selected = expert_ids == expert
        if not selected.any():
            # No row was routed to this expert.
            continue
        activation = SiluAndMul()
        projected = routed[selected] @ w1[expert].transpose(0, 1)
        expert_out[selected] = activation(projected) @ w2[expert].transpose(
            0, 1)

    per_row = expert_out.view(num_rows, -1, out_width)
    scale = routing_weights.view(num_rows, -1, 1).to(expert_out.dtype)
    return (per_row * scale).sum(dim=1)


def torch_moe_single(a, w, score, topk):
    """
    Plain-PyTorch reference for a single routed expert matmul.

    Each row of ``a`` is duplicated ``topk`` times and routed to the experts
    with the highest softmax scores; rows routed to expert ``i`` are
    multiplied by ``w[i]`` transposed. The ``topk`` results belonging to one
    input row are summed without applying the routing weights.

    Args:
        a: 2-D input, unpacked as (B, D).
        w: per-expert weights, indexed by expert along dim 0;
            ``w.shape[1]`` is the output width.
        score: routing logits; softmax is taken over the last dim in float32.
        topk: number of experts selected per input row.

    Returns:
        Tensor of shape (B, w.shape[1]).
    """
    num_rows, width = a.shape
    out_width = w.shape[1]

    # One copy of every input row per selected expert: (B * topk, D).
    routed = a.view(num_rows, -1, width).repeat(1, topk, 1).reshape(-1, width)
    expert_out = torch.zeros(num_rows * topk,
                             out_width,
                             dtype=routed.dtype,
                             device=routed.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    # Only the selected ids matter here; the weights are discarded.
    expert_ids = torch.topk(probs, topk)[1].view(-1)

    for expert in range(w.shape[0]):
        selected = expert_ids == expert
        if not selected.any():
            continue
        expert_out[selected] = routed[selected] @ w[expert].transpose(0, 1)

    return expert_out.view(num_rows, -1, out_width).sum(dim=1)


# A special version of torch.library.opcheck that defaults test_utils to
# ALL_OPCHECK_TEST_UTILS and runs with torch.allclose patched by a version
# that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: tuple[Any, ...],
            kwargs: Optional[dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> dict[str, str]:
    """
    Run torch.library.opcheck with torch.allclose replaced by fp8_allclose.

    Args:
        op: the operator to check.
        args: positional arguments forwarded to the operator check.
        kwargs: keyword arguments forwarded to the operator check.
        test_utils: name(s) of the opcheck tests to run.
        raise_exception: forwarded to torch.library.opcheck.
        cond: when False the check is skipped entirely.

    Returns:
        The result of torch.library.opcheck, or an empty dict when ``cond``
        is False.
    """
    # A bare `import unittest` does not load the `unittest.mock` submodule,
    # so import it explicitly instead of relying on another package having
    # imported it as a side effect.
    from unittest.mock import patch

    if not cond:
        # Nothing to run, so there is no need to patch torch.allclose.
        return {}

    with patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(op,
                                     args,
                                     kwargs,
                                     test_utils=test_utils,
                                     raise_exception=raise_exception)


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    """
    Convert ``tensor`` to torch.int8.

    Values are first clamped to [-128, 127], then rounded to the nearest
    integer, then cast.
    """
    bounds = torch.iinfo(torch.int8)
    saturated = tensor.clamp(min=bounds.min, max=bounds.max)
    rounded = torch.round(saturated)
    return rounded.to(dtype=torch.int8)


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Reference scaled matrix multiply computed in float32.

    Computes ``(scale_a * a) @ (scale_b * b)`` with ``a`` and ``b`` upcast
    to float32, casts the product to ``out_dtype`` and, if given, adds
    ``bias`` afterwards.

    Args:
        a: left operand (2-D, as required by torch.mm).
        b: right operand (2-D, as required by torch.mm).
        scale_a: scale for ``a``; group-broadcast to ``a.shape`` (see below).
        scale_b: scale for ``b``; group-broadcast to ``b.shape`` (see below).
        out_dtype: dtype the matmul result is cast to.
        bias: optional tensor added to the cast result.

    Returns:
        The scaled product, plus ``bias`` when provided.
    """

    # We treat N-dimensional group scaling as extended numpy-style
    # broadcasting. NumPy simply stretches dimensions with an extent of 1 to
    # match the target shape by repeating the data along that dimension
    # (broadcasting). We extend these semantics: if the extent of a dimension
    # in the source shape is not 1 and does not match the target shape, each
    # element along that dimension is repeated
    # target_shape[dim] // src_shape[dim] times. For example, given:
    #       a = [[1, 2], and target_shape = (2, 4)
    #            [3, 4]]
    # a is expanded to:
    #       a = [[1, 1, 2, 2],
    #            [3, 3, 4, 4]]
    # NOTE: this function does not explicitly broadcast dimensions with an
    # extent of 1, since this can be done implicitly by pytorch.
    def group_broadcast(t, shape):
        # Scales with fewer dims than the target (e.g. a 0-d per-tensor
        # scale) are aligned on the trailing dims, numpy-style, by
        # prepending unit dims; without this, indexing t.shape[i] below
        # would raise IndexError.
        missing_dims = len(shape) - t.dim()
        if missing_dims > 0:
            t = t.reshape((1, ) * missing_dims + tuple(t.shape))

        for i, s in enumerate(shape):
            extent = t.shape[i]
            if extent != s and extent != 1:
                assert s % extent == 0, (
                    f"cannot group-broadcast dim {i}: target extent {s} is "
                    f"not a multiple of source extent {extent}")
                # Insert a repeat axis right after dim i, expand it to the
                # group size, then merge it back into dim i.
                t = t.unsqueeze(i + 1)\
                  .expand(*t.shape[:i+1], s // extent, *t.shape[i+1:])\
                  .flatten(i, i + 1)
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    output = torch.mm((scale_a * a.to(dtype=torch.float32)),
                      (scale_b * b.to(dtype=torch.float32))).to(out_dtype)

    if bias is not None:
        output = output + bias

    return output
